--- ../../plenoxels/opt/util/nerf_dataset.py	2024-02-13 12:06:07.930166019 -0500
+++ ../google/plenoxels/opt/util/nerf_dataset.py	2024-02-13 10:46:24.448645161 -0500
@@ -1,5 +1,10 @@
+import os
+if 'DEBUG_SESSION' in os.environ and os.environ['DEBUG_SESSION'] != 'False':
+    # import debug tools
+    from my_python_utils.common_utils import *
+
 # Standard NeRF Blender dataset loader
-from .util import Rays, Intrin, select_or_shuffle_rays
+from .util import Rays, Intrin, select_or_shuffle_rays, construct_mask_mae
 from .dataset_base import DatasetBase
 import torch
 import torch.nn.functional as F
@@ -10,7 +15,8 @@
 import cv2
 import json
 import numpy as np
-
+import os
+import glob
 
 class NeRFDataset(DatasetBase):
     """
@@ -27,6 +33,7 @@
     n_images: int
     rays: Optional[Rays]
     split: str
+    mask: bool
 
     def __init__(
         self,
@@ -40,9 +47,13 @@
         permutation: bool = True,
         white_bkgd: bool = True,
         n_images = None,
+        opencv_mode = True,
+        rays_at_center_pixel=True,
+        load_depth_if_exists=False,
+        masking_percentage=0.0,
         **kwargs
     ):
-        super().__init__()
+        super().__init__(opencv_mode=opencv_mode, rays_at_center_pixel=rays_at_center_pixel)
         assert path.isdir(root), f"'{root}' is not a directory"
 
         if scene_scale is None:
@@ -54,6 +65,10 @@
         self.epoch_size = epoch_size
         all_c2w = []
         all_gt = []
+        all_invalid_pixels = []
+        
+        if load_depth_if_exists:
+            all_depths = []
 
         split_name = split if split != "test_train" else "train"
         data_path = path.join(root, split_name)
@@ -66,12 +81,14 @@
         j = json.load(open(data_json, "r"))
 
         # OpenGL -> OpenCV
-        cam_trans = torch.diag(torch.tensor([1, -1, -1, 1], dtype=torch.float32))
+        if opencv_mode:
+            cam_trans = torch.diag(torch.tensor([1, -1, -1, 1], dtype=torch.float32))
 
         for frame in tqdm(j["frames"]):
             fpath = path.join(data_path, path.basename(frame["file_path"]) + ".png")
             c2w = torch.tensor(frame["transform_matrix"], dtype=torch.float32)
-            c2w = c2w @ cam_trans  # To OpenCV
+            if opencv_mode:
+                c2w = c2w @ cam_trans  # To OpenCV
 
             im_gt = imageio.imread(fpath)
             if scale < 1.0:
@@ -81,6 +98,38 @@
 
             all_c2w.append(c2w)
             all_gt.append(torch.from_numpy(im_gt))
+            # for hardcoded masks, unused
+            if False:
+                invalid_pixels_path = fpath.replace(".png", "_invalid_pixels.png")
+                if os.path.exists(invalid_pixels_path):
+                    im_invalid_px = imageio.imread(invalid_pixels_path)
+                    if scale < 1.0:
+                        im_invalid_px = cv2.resize(im_invalid_px, (rsz_w, rsz_h), interpolation=cv2.INTER_NEAREST)
+
+                    all_invalid_pixels.append(torch.from_numpy(im_invalid_px[...,0] == 255))
+            elif masking_percentage > 0:
+                assert masking_percentage < 1.0, "masking_percentage must be in [0,1)"
+                assert im_gt.shape[0] == im_gt.shape[1], "masking only supported for square images"
+                active_pixels = construct_mask_mae(masking_percentage, im_gt.shape[:2], patch_size = im_gt.shape[1] // 20, random_seed=1337 + len(all_gt) - 1)
+                invalid_pixels = 1 - active_pixels
+                all_invalid_pixels.append(torch.from_numpy(invalid_pixels))
+                    
+            if load_depth_if_exists:
+                depth_path_wildcard = fpath.replace('.png', '') + '_depth_*'  # sibling files named "<frame>_depth_*"
+                files = glob.glob(depth_path_wildcard)
+                if len(files) == 1:  # only an unambiguous single match is loaded
+                    disparity_px = imageio.imread(files[0])[...,0]
+                    depth_px = np.zeros_like(disparity_px, dtype='float32')
+                    depth_px[disparity_px != 0 ] = 1.0 / disparity_px[disparity_px != 0]  # invert non-zero values; zeros stay 0
+                    all_depths.append(torch.from_numpy(depth_px))
+                else:
+                    print("Requested to load depths but depth file ({}) does not exist".format(depth_path_wildcard))
+                
+                
+        if len(all_invalid_pixels) > 0:
+            assert len(all_invalid_pixels) == len(all_gt)
+            self.invalid_pixels = torch.stack(all_invalid_pixels)
+
         focal = float(
             0.5 * all_gt[0].shape[1] / np.tan(0.5 * j["camera_angle_x"])
         )
@@ -88,6 +137,8 @@
         self.c2w[:, :3, 3] *= scene_scale
 
         self.gt = torch.stack(all_gt).float() / 255.0
+        if load_depth_if_exists:
+            self.depths = torch.stack(all_depths).float() / 255.0
         if self.gt.size(-1) == 4:
             if white_bkgd:
                 # Apply alpha channel
@@ -120,3 +171,18 @@
 
         self.should_use_background = False  # Give warning
 
+    # to visualize training/val as sorted video. could be nicer, but better than the original sorting
+    def get_ids_sorted_by_position(self):
+        """Order image ids as a greedy nearest-neighbour chain over camera positions.
+
+        Starts at id 0 and repeatedly hops to the not-yet-visited pose whose
+        ``c2w[:3, 3]`` column is closest in L2; returns the visited ids in order.
+        """
+        poses = self.c2w
+        remaining = list(range(len(poses)))
+        ordered = [remaining.pop(0)]
+        while remaining:
+            offsets = poses[ordered[-1]][:3, 3] - poses[remaining][:, :3, 3]
+            dists = np.linalg.norm(offsets, axis=-1)
+            ordered.append(remaining.pop(np.argmin(dists)))
+        return ordered
